{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "00de35c6",
   "metadata": {},
   "source": [
    "# Demo for Toward Universal Text-to-Music Retrieval\n",
    "\n",
    "- arXiv: https://arxiv.org/abs/2211.14558\n",
    "- pretrained model: https://zenodo.org/record/7322135\n",
    "- github repo: https://github.com/seungheondoh/music-text-representation\n",
    "- demo site: https://seungheondoh.github.io/text-music-representation-demo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "22a3aee0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import pickle\n",
    "import torch\n",
    "from torch import nn\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import IPython.display as ipd\n",
    "from IPython.display import Audio, HTML\n",
    "\n",
    "import argparse\n",
    "from mtr.utils.demo_utils import get_model\n",
    "from mtr.utils.eval_utils import _text_representation\n",
    "import warnings\n",
    "warnings.filterwarnings(action='ignore')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0932134e",
   "metadata": {},
   "source": [
    "## Load Pretrained Model & Metadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e2782b2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset location; see https://github.com/seungheondoh/msd-subsets for the expected layout.\n",
    "# Set `your_msd_path` to the directory that contains `msd-subsets/dataset`.\n",
    "# With the empty default, `msd_path` is resolved relative to the current working directory.\n",
    "your_msd_path = \"\"\n",
    "msd_path = os.path.join(your_msd_path, \"msd-subsets/dataset\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1ea3bbc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: a `global` statement at module (cell) level has no effect in Python.\n",
    "# `msd_to_id` and `id_to_path` are assigned in the next cell and are read as\n",
    "# module-level variables inside `retrieval_fn`.\n",
    "global msd_to_id\n",
    "global id_to_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "45e6eef6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the id lookup tables and the annotation file from the dataset directory.\n",
    "# `with` blocks make sure every file handle is closed as soon as it has been read.\n",
    "# SECURITY: `pickle.load` can execute arbitrary code - only load dataset files you trust.\n",
    "with open(os.path.join(msd_path, \"lastfm_annotation\", \"MSD_id_to_7D_id.pkl\"), 'rb') as f:\n",
    "    msd_to_id = pickle.load(f)\n",
    "with open(os.path.join(msd_path, \"lastfm_annotation\", \"7D_id_to_path.pkl\"), 'rb') as f:\n",
    "    id_to_path = pickle.load(f)\n",
    "with open(os.path.join(msd_path, \"ecals_annotation/annotation.json\"), 'r') as f:\n",
    "    annotation = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bf981c9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pre_extract_audio_embedding(framework, text_type, text_rep):\n",
    "    \"\"\"Load the pre-extracted audio embeddings of one experiment.\n",
    "\n",
    "    The embeddings are read from\n",
    "    ``../mtr/{framework}/exp/transformer_cnn_cf_mel/{text_type}_{text_rep}/audio_embs.pt``,\n",
    "    which is expected to hold a mapping from track id to embedding tensor\n",
    "    (all tensors must share one shape, as required by ``torch.stack``).\n",
    "\n",
    "    Returns:\n",
    "        ``(audio_embs, msdid)``: the embeddings stacked along a new first dimension and\n",
    "        the list of mapping keys, in the same order as the stacked rows.\n",
    "    \"\"\"\n",
    "    emb_path = f\"../mtr/{framework}/exp/transformer_cnn_cf_mel/{text_type}_{text_rep}/audio_embs.pt\"\n",
    "    # map_location=\"cpu\" lets a file saved from a GPU run load on a CPU-only machine and\n",
    "    # keeps the tensors on the CPU, where `retrieval_fn` later converts scores with `.numpy()`.\n",
    "    ecals_test = torch.load(emb_path, map_location=\"cpu\")\n",
    "    msdid = list(ecals_test.keys())\n",
    "    audio_embs = torch.stack([ecals_test[k] for k in msdid])\n",
    "    return audio_embs, msdid\n",
    "\n",
    "def model_load(framework, text_type, text_rep):\n",
    "    \"\"\"Return ``(model, audio_embs, tokenizer, msdid)`` for the selected experiment.\n",
    "\n",
    "    The audio embeddings and their ids come from ``pre_extract_audio_embedding``;\n",
    "    the model and tokenizer come from ``get_model``, called with the same three settings.\n",
    "    \"\"\"\n",
    "    embeddings, track_ids = pre_extract_audio_embedding(framework, text_type, text_rep)\n",
    "    # `get_model` also returns a config object, which this notebook does not use.\n",
    "    model, tokenizer, _unused_config = get_model(\n",
    "        framework=framework, text_type=text_type, text_rep=text_rep\n",
    "    )\n",
    "    return model, embeddings, tokenizer, track_ids"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2023f8f3",
   "metadata": {},
   "source": [
    "## Query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2e3465dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def retrieval_fn(query, tokenizer, model, audio_embs, msdid, annotation, top_k=3):\n",
    "    \"\"\"Retrieve the ``top_k`` tracks whose audio embedding best matches a text query.\n",
    "\n",
    "    Args:\n",
    "        query: Text query; it is tokenized with ``tokenizer``.\n",
    "        tokenizer: Callable returning a mapping with ``'input_ids'`` when called with\n",
    "            ``return_tensors=\"pt\"``.\n",
    "        model: Model exposing ``encode_bert_text(input_ids, None)``.\n",
    "        audio_embs: 2-D tensor of audio embeddings, one row per entry of ``msdid``.\n",
    "        msdid: Track ids aligned with the rows of ``audio_embs``.\n",
    "        annotation: Mapping from track id to a record that has a ``'tag'`` entry.\n",
    "        top_k: Number of best-scoring tracks to return (default 3).\n",
    "\n",
    "    Returns:\n",
    "        ``(instance, metadata)``: two dicts keyed ``'top1 music'``, ``'top2 music'``, ...\n",
    "        ``instance`` holds the HTML ``<audio>`` player markup of each track and\n",
    "        ``metadata`` holds the track's ``'tag'`` annotation.\n",
    "\n",
    "    Note:\n",
    "        Reads the module-level ``msd_path``, ``msd_to_id`` and ``id_to_path`` to locate\n",
    "        the audio files.\n",
    "    \"\"\"\n",
    "    text_input = tokenizer(query, return_tensors=\"pt\")['input_ids']\n",
    "    # The whole scoring step is inference only, so keep it out of autograd; this also\n",
    "    # guarantees that `.numpy()` below works even if the stored embeddings require grad.\n",
    "    with torch.no_grad():\n",
    "        text_embs = model.encode_bert_text(text_input, None)\n",
    "        audio_embs = nn.functional.normalize(audio_embs, dim=1)\n",
    "        text_embs = nn.functional.normalize(text_embs, dim=1)\n",
    "        # Both sides are L2-normalized, so the dot product is the cosine similarity.\n",
    "        logits = text_embs @ audio_embs.T\n",
    "    ret_item = pd.Series(logits.squeeze(0).numpy(), index=msdid)\n",
    "    instance, metadata = {}, {}\n",
    "    top_ids = ret_item.sort_values(ascending=False).head(top_k).index\n",
    "    for rank, _id in enumerate(top_ids, start=1):\n",
    "        audio_path = os.path.join(msd_path, 'songs', id_to_path[msd_to_id[_id]])\n",
    "        music_src = ipd.Audio(audio_path, rate=22050).src_attr()\n",
    "        metadata[f'top{rank} music'] = annotation[_id]['tag']\n",
    "        # Only the player markup is emitted; the surrounding table cell is produced by\n",
    "        # `DataFrame.to_html`, so no closing </td> belongs here.\n",
    "        instance[f'top{rank} music'] = f\"\"\"<audio controls><source src=\"{music_src}\" type=\"audio/wav\"></audio>\"\"\"\n",
    "    return instance, metadata\n",
    "\n",
    "def retrieval_show(framework, text_type, text_rep, annotation, query, is_audio=True):\n",
    "    \"\"\"Run text-to-music retrieval for every query and display the result as an HTML table.\n",
    "\n",
    "    Args:\n",
    "        framework, text_type, text_rep: Forwarded to ``model_load`` to select the model.\n",
    "        annotation: Forwarded to ``retrieval_fn``.\n",
    "        query: A list of text queries, or a single query string.\n",
    "        is_audio: If True the table shows audio players; otherwise it shows the tag\n",
    "            metadata of the retrieved tracks.\n",
    "    \"\"\"\n",
    "    # A bare string would otherwise be iterated character by character.\n",
    "    if isinstance(query, str):\n",
    "        query = [query]\n",
    "    model, audio_embs, tokenizer, msdid = model_load(framework, text_type, text_rep)\n",
    "    meta_results, retrieval_results = [], []\n",
    "    for text in query:\n",
    "        instance, metadata = retrieval_fn(text, tokenizer, model, audio_embs, msdid, annotation)\n",
    "        retrieval_results.append(instance)\n",
    "        meta_results.append(metadata)\n",
    "    rows = retrieval_results if is_audio else meta_results\n",
    "    # escape=False keeps the <audio> markup as live HTML instead of escaped text.\n",
    "    html = pd.DataFrame(rows, index=query).to_html(escape=False)\n",
    "    ipd.display(HTML(html))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4bada436",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example free-text queries; the list `query` is what `retrieval_show` iterates over.\n",
    "# NOTE(review): the names suggest a tag-style, a caption-style and an unseen-style query - confirm.\n",
    "tag_query = \"banjo\"\n",
    "caption_query = \"fusion jazz with synth, bass, drums, saxophone\"\n",
    "unseen_query = \"music for meditation or listen to in the forest\"\n",
    "query = [tag_query, caption_query, unseen_query]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "134e6c1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model selection, forwarded to `retrieval_show` (and from there to `model_load`).\n",
    "framework='contrastive' # alternative noted by the author: triplet\n",
    "text_type='bert' # alternatives noted by the author: tag, caption\n",
    "text_rep=\"stochastic\"\n",
    "# Uncomment to run the retrieval and display the tag metadata table (is_audio=False):\n",
    "# retrieval_show(framework, text_type, text_rep, annotation, query, is_audio=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b8c48b8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
